// Copyright (c) 2024 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <limits>
#include <sycl/sycl.hpp>

// #define CINN_SYCL_FP16
// #define CINN_SYCL_BF16
/**
 * \file This file contains all the intrinsics available to be used in SYCL code
 * generated by CodeGen.
 */

extern "C" {

// Upper bound on the sub-group width assumed by the shuffle reductions in this
// file; their shuffle offsets start at MAX_SUBGROUP_SIZE / 2.
#define MAX_SUBGROUP_SIZE 64
// Upper bound on the number of work-items in one work-group.
// NOTE(review): presumably matches the launch limits used by CodeGen — confirm.
#define MAX_THREADNUM_INGROUP 1024
// ceil(MAX_THREADNUM_INGROUP / MAX_SUBGROUP_SIZE): the number of sub-groups in
// a full work-group when every sub-group has MAX_SUBGROUP_SIZE work-items.
#define MAX_SUBGROUPNUM_INGROUP \
  ((MAX_THREADNUM_INGROUP - 1) / MAX_SUBGROUP_SIZE + 1)

// *************************************************************** //
// bool unary and binary operator
#define FN_BOOL(func) cinn_sycl_##func##_bool
// Bit operations on `bool` operands. A bool holds a single bit of information,
// so each operation is written with the equivalent logical expression.
inline bool FN_BOOL(bitwise_and)(bool a, bool b) {
  return a && b;
}
inline bool FN_BOOL(bitwise_or)(bool a, bool b) {
  return a || b;
}
inline bool FN_BOOL(bitwise_xor)(bool a, bool b) {
  return a != b;
}
inline bool FN_BOOL(bitwise_not)(bool a) {
  return a ? false : true;
}

// *************************************************************** //
// uint8 unary and binary operator
#define FN_UINT8(func) cinn_sycl_##func##_uint8
// Bit operations on uint8_t. Operands are promoted to int by the built-in
// operators; the result is narrowed back to uint8_t on return.
inline uint8_t FN_UINT8(bitwise_and)(uint8_t a, uint8_t b) {
  return static_cast<uint8_t>(a & b);
}
inline uint8_t FN_UINT8(bitwise_or)(uint8_t a, uint8_t b) {
  return static_cast<uint8_t>(a | b);
}
inline uint8_t FN_UINT8(bitwise_xor)(uint8_t a, uint8_t b) {
  return static_cast<uint8_t>(a ^ b);
}
inline uint8_t FN_UINT8(bitwise_not)(uint8_t a) {
  return static_cast<uint8_t>(~a);
}
// Right shift that fills the vacated high bits with zeros.
inline uint8_t FN_UINT8(logical_right_shift)(uint8_t a, uint8_t b) {
  const uint8_t bits = a;
  return static_cast<uint8_t>(bits >> b);
}

// *************************************************************** //
// int8 unary and binary operator
#define FN_INT8(func) cinn_sycl_##func##_int8
// Bit operations on int8_t. Operands are promoted to int by the built-in
// operators; the result is narrowed back to int8_t on return.
inline int8_t FN_INT8(bitwise_and)(int8_t a, int8_t b) {
  return static_cast<int8_t>(a & b);
}
inline int8_t FN_INT8(bitwise_or)(int8_t a, int8_t b) {
  return static_cast<int8_t>(a | b);
}
inline int8_t FN_INT8(bitwise_xor)(int8_t a, int8_t b) {
  return static_cast<int8_t>(a ^ b);
}
inline int8_t FN_INT8(bitwise_not)(int8_t a) {
  return static_cast<int8_t>(~a);
}
// Right shift on the 8-bit pattern of `a`: the value is reinterpreted as
// uint8_t first so that zeros (not sign bits) are shifted in.
inline int8_t FN_INT8(logical_right_shift)(int8_t a, int8_t b) {
  const uint8_t bits = static_cast<uint8_t>(a);
  return static_cast<int8_t>(bits >> b);
}

// *************************************************************** //
// int16 unary and binary operator
#define FN_INT16(func) cinn_sycl_##func##_int16
// Bit operations on int16_t. Operands are promoted to int by the built-in
// operators; the result is narrowed back to int16_t on return.
inline int16_t FN_INT16(bitwise_and)(int16_t a, int16_t b) {
  return static_cast<int16_t>(a & b);
}
inline int16_t FN_INT16(bitwise_or)(int16_t a, int16_t b) {
  return static_cast<int16_t>(a | b);
}
inline int16_t FN_INT16(bitwise_xor)(int16_t a, int16_t b) {
  return static_cast<int16_t>(a ^ b);
}
inline int16_t FN_INT16(bitwise_not)(int16_t a) {
  return static_cast<int16_t>(~a);
}
// Right shift on the 16-bit pattern of `a`: the value is reinterpreted as
// uint16_t first so that zeros (not sign bits) are shifted in.
inline int16_t FN_INT16(logical_right_shift)(int16_t a, int16_t b) {
  const uint16_t bits = static_cast<uint16_t>(a);
  return static_cast<int16_t>(bits >> b);
}

// *************************************************************** //
// float32 unary and binary operator
#define FN_FP32(func) cinn_sycl_##func##_fp32
// Single-precision math intrinsics; each one forwards to the SYCL built-in of
// the same meaning.
// NOTE: every intrinsic carries an explicit type suffix (here `_fp32`) in its
// name: this file is wrapped in `extern "C"`, so the per-type variants are
// distinguished by name rather than by overloading.

// Trigonometric and hyperbolic functions.
inline float FN_FP32(sin)(float x) { return sycl::sin(x); }
inline float FN_FP32(cos)(float x) { return sycl::cos(x); }
inline float FN_FP32(tan)(float x) { return sycl::tan(x); }
inline float FN_FP32(sinh)(float x) { return sycl::sinh(x); }
inline float FN_FP32(cosh)(float x) { return sycl::cosh(x); }
inline float FN_FP32(tanh)(float x) { return sycl::tanh(x); }

// Inverse trigonometric and inverse hyperbolic functions.
inline float FN_FP32(asin)(float x) { return sycl::asin(x); }
inline float FN_FP32(acos)(float x) { return sycl::acos(x); }
inline float FN_FP32(atan)(float x) { return sycl::atan(x); }
inline float FN_FP32(asinh)(float x) { return sycl::asinh(x); }
inline float FN_FP32(acosh)(float x) { return sycl::acosh(x); }
inline float FN_FP32(atanh)(float x) { return sycl::atanh(x); }

// Rounding, absolute value, logarithms, exponentials and roots.
inline float FN_FP32(ceil)(float x) { return sycl::ceil(x); }
inline float FN_FP32(round)(float x) { return sycl::round(x); }
inline float FN_FP32(trunc)(float x) { return sycl::trunc(x); }
// `abs` maps to the floating-point absolute value, sycl::fabs.
inline float FN_FP32(abs)(float x) { return sycl::fabs(x); }
inline float FN_FP32(floor)(float x) { return sycl::floor(x); }
inline float FN_FP32(log)(float x) { return sycl::log(x); }
inline float FN_FP32(log2)(float x) { return sycl::log2(x); }
inline float FN_FP32(log10)(float x) { return sycl::log10(x); }
inline float FN_FP32(exp)(float x) { return sycl::exp(x); }
inline float FN_FP32(erf)(float x) { return sycl::erf(x); }
// Logistic function 1 / (1 + e^-x).
inline float FN_FP32(sigmoid)(float x) { return 1.0f / (1.0f + sycl::exp(-x)); }
inline float FN_FP32(sqrt)(float x) { return sycl::sqrt(x); }
inline float FN_FP32(rsqrt)(float x) { return sycl::rsqrt(x); }
inline float FN_FP32(cbrt)(float x) { return sycl::cbrt(x); }

// Classification predicates.
inline bool FN_FP32(isfinite)(float x) { return sycl::isfinite(x); }
inline bool FN_FP32(isinf)(float x) { return sycl::isinf(x); }
inline bool FN_FP32(isnan)(float x) { return sycl::isnan(x); }

// a raised to the power b.
inline float FN_FP32(pow)(float a, float b) { return sycl::pow(a, b); }

// Floating-point modulo whose non-zero result takes the sign of the divisor
// `b`: the sycl::fmod remainder is shifted by `b` when its sign differs from
// the sign of `b`.
inline float FN_FP32(mod)(float a, float b) {
  const float rem = sycl::fmod(a, b);
  const bool sign_differs = (rem < 0.0f) != (b < 0.0f);
  return ((rem != 0.0f) && sign_differs) ? rem + b : rem;
}

// *************************************************************** //
// float64 unary and binary operator
#define FN_FP64(func) cinn_sycl_##func##_fp64

// Double-precision math intrinsics; each one forwards to the SYCL built-in of
// the same meaning.

// Trigonometric and hyperbolic functions.
inline double FN_FP64(sin)(double x) { return sycl::sin(x); }
inline double FN_FP64(cos)(double x) { return sycl::cos(x); }
inline double FN_FP64(tan)(double x) { return sycl::tan(x); }
inline double FN_FP64(sinh)(double x) { return sycl::sinh(x); }
inline double FN_FP64(cosh)(double x) { return sycl::cosh(x); }
inline double FN_FP64(tanh)(double x) { return sycl::tanh(x); }

// Inverse trigonometric and inverse hyperbolic functions.
inline double FN_FP64(asin)(double x) { return sycl::asin(x); }
inline double FN_FP64(acos)(double x) { return sycl::acos(x); }
inline double FN_FP64(atan)(double x) { return sycl::atan(x); }
inline double FN_FP64(asinh)(double x) { return sycl::asinh(x); }
inline double FN_FP64(acosh)(double x) { return sycl::acosh(x); }
inline double FN_FP64(atanh)(double x) { return sycl::atanh(x); }

// Rounding, absolute value, logarithms, exponentials and roots.
inline double FN_FP64(ceil)(double x) { return sycl::ceil(x); }
inline double FN_FP64(round)(double x) { return sycl::round(x); }
inline double FN_FP64(trunc)(double x) { return sycl::trunc(x); }
// `abs` maps to the floating-point absolute value, sycl::fabs.
inline double FN_FP64(abs)(double x) { return sycl::fabs(x); }
inline double FN_FP64(floor)(double x) { return sycl::floor(x); }
inline double FN_FP64(log)(double x) { return sycl::log(x); }
inline double FN_FP64(log2)(double x) { return sycl::log2(x); }
inline double FN_FP64(log10)(double x) { return sycl::log10(x); }
inline double FN_FP64(exp)(double x) { return sycl::exp(x); }
inline double FN_FP64(erf)(double x) { return sycl::erf(x); }
// Logistic function 1 / (1 + e^-x).
inline double FN_FP64(sigmoid)(double x) { return 1.0 / (1.0 + sycl::exp(-x)); }
inline double FN_FP64(sqrt)(double x) { return sycl::sqrt(x); }
inline double FN_FP64(rsqrt)(double x) { return sycl::rsqrt(x); }
inline double FN_FP64(cbrt)(double x) { return sycl::cbrt(x); }

// Classification predicates.
inline bool FN_FP64(isfinite)(double x) { return sycl::isfinite(x); }
inline bool FN_FP64(isinf)(double x) { return sycl::isinf(x); }
inline bool FN_FP64(isnan)(double x) { return sycl::isnan(x); }

// a raised to the power b.
inline double FN_FP64(pow)(double a, double b) { return sycl::pow(a, b); }
// Floating-point modulo whose non-zero result takes the sign of the divisor
// `b`: the sycl::fmod remainder is shifted by `b` when its sign differs from
// the sign of `b`.
inline double FN_FP64(mod)(double a, double b) {
  const double rem = sycl::fmod(a, b);
  const bool sign_differs = (rem < 0.0) != (b < 0.0);
  return ((rem != 0.0) && sign_differs) ? rem + b : rem;
}

// *************************************************************** //
// int32 unary and binary operator
#define FN_INT32(func) cinn_sycl_##func##_int32

// Integer power a^b.
// Returns -1 for the undefined case 0 raised to a negative exponent.
// Otherwise the base is converted to float (rounding toward negative
// infinity), raised to the integer power `b` with sycl::pown, and the float
// result is converted back to int with round-to-nearest-even.
// NOTE(review): the computation goes through `float`, so results that a float
// cannot represent exactly are presumably approximate — confirm this is
// acceptable for callers.
inline int FN_INT32(pow)(int a, int b) {
  if (a == 0 && b < 0) {
    return -1;
  }
  float res = sycl::pown(
      sycl::vec<int, 1>{a}.convert<float, sycl::rounding_mode::rtn>()[0], b);
  return sycl::vec<float, 1>{res}.convert<int, sycl::rounding_mode::rte>()[0];
}

// Shift and bit operations on int; each maps directly to the built-in operator.
inline int FN_INT32(left_shift)(int a, int b) { return a << b; }
// Right shift applied to the signed value (see logical_right_shift for the
// unsigned variant).
inline int FN_INT32(right_shift)(int a, int b) { return a >> b; }
inline int FN_INT32(bitwise_and)(int a, int b) { return a & b; }
inline int FN_INT32(bitwise_or)(int a, int b) { return a | b; }
inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; }
inline int FN_INT32(bitwise_not)(int a) { return ~a; }
// Count of leading zero bits (sycl::clz) and count of set bits
// (sycl::popcount).
inline int FN_INT32(clz)(int a) { return sycl::clz(a); }
inline int FN_INT32(popc)(int a) { return sycl::popcount(a); }
// Right shift performed on the unsigned reinterpretation of `a`, so zeros are
// shifted in from the left.
inline int FN_INT32(logical_right_shift)(int a, int b) {
  return ((unsigned int)a >> b);
}
// Truncation of an integer is the identity.
inline int FN_INT32(trunc)(int a) { return a; }

inline int FN_INT32(max)(int a, int b) { return sycl::max(a, b); }
inline int FN_INT32(min)(int a, int b) { return sycl::min(a, b); }

// Integer modulo whose non-zero result takes the sign of the divisor `b`:
// the `%` remainder is shifted by `b` when the sign bits of the remainder and
// of `b` differ (detected with XOR).
inline int FN_INT32(mod)(int a, int b) {
  const int rem = a % b;
  const bool sign_differs = (b ^ rem) < 0;
  return ((rem != 0) && sign_differs) ? rem + b : rem;
}

// *************************************************************** //

// int64 unary and binary operator
#define FN_INT64(func) cinn_sycl_##func##_int64

// Bit operations on int64_t; each maps directly to the built-in operator.
inline int64_t FN_INT64(bitwise_and)(int64_t a, int64_t b) { return a & b; }
inline int64_t FN_INT64(bitwise_or)(int64_t a, int64_t b) { return a | b; }
inline int64_t FN_INT64(bitwise_xor)(int64_t a, int64_t b) { return a ^ b; }
inline int64_t FN_INT64(bitwise_not)(int64_t a) { return ~a; }
// Count of leading zero bits (sycl::clz) and count of set bits
// (sycl::popcount).
inline int64_t FN_INT64(clz)(int64_t a) { return sycl::clz(a); }
inline int64_t FN_INT64(popc)(int64_t a) { return sycl::popcount(a); }
// Right shift performed on the unsigned reinterpretation of `a`, so zeros are
// shifted in from the left.
inline int64_t FN_INT64(logical_right_shift)(int64_t a, int64_t b) {
  return ((uint64_t)a >> b);
}
// Truncation of an integer is the identity.
inline int64_t FN_INT64(trunc)(int64_t a) { return a; }
// Integer modulo whose non-zero result takes the sign of the divisor `b`:
// the `%` remainder is shifted by `b` when the sign bits of the remainder and
// of `b` differ (detected with XOR).
inline int64_t FN_INT64(mod)(int64_t a, int64_t b) {
  const int64_t rem = a % b;
  const bool sign_differs = (b ^ rem) < 0;
  return ((rem != 0) && sign_differs) ? rem + b : rem;
}

// Integer power a^b for 64-bit operands.
// The base is converted to double (rounding toward negative infinity) and the
// exponent to int, the power is evaluated with sycl::pown, and the double
// result is converted back to int64_t with round-to-nearest-even.
// Fix: the exponent is now derived from `b`; it was previously derived from
// `a`, which made the function compute a^a.
inline int64_t FN_INT64(pow)(int64_t a, int64_t b) {
  const double base =
      sycl::vec<int64_t, 1>{a}.convert<double, sycl::rounding_mode::rtn>()[0];
  // sycl::pown takes an `int` exponent, so `b` is narrowed to int first.
  const int exponent =
      sycl::vec<int64_t, 1>{b}.convert<int, sycl::rounding_mode::rtn>()[0];
  const double res = sycl::pown(base, exponent);
  return sycl::vec<double, 1>{res}
      .convert<int64_t, sycl::rounding_mode::rte>()[0];
}

// *************************************************************** //
// bfloat16 unary and binary operator
// Compiled only when CINN_SYCL_BF16 is defined.
// NOTE(review): `bfloat16` is used unqualified and is not declared in this
// file — presumably it is made visible before this point whenever
// CINN_SYCL_BF16 is defined; confirm.
#ifdef CINN_SYCL_BF16

#define FN_BF16(func) cinn_sycl_##func##_bf16

// bfloat16 math intrinsics; each forwards to the function of the same name in
// sycl::ext::oneapi::experimental.
inline bfloat16 FN_BF16(ceil)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::ceil(x);
}
inline bfloat16 FN_BF16(floor)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::floor(x);
}

inline bfloat16 FN_BF16(trunc)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::trunc(x);
}

inline bfloat16 FN_BF16(sin)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::sin(x);
}
inline bfloat16 FN_BF16(cos)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::cos(x);
}

inline bfloat16 FN_BF16(exp)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::exp(x);
}
inline bfloat16 FN_BF16(log)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::log(x);
}
inline bfloat16 FN_BF16(log2)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::log2(x);
}
inline bfloat16 FN_BF16(log10)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::log10(x);
}

inline bfloat16 FN_BF16(sqrt)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::sqrt(x);
}
inline bfloat16 FN_BF16(rsqrt)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::rsqrt(x);
}

// `abs` maps to the floating-point absolute value, fabs.
inline bfloat16 FN_BF16(abs)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::fabs(x);
}

inline bool FN_BF16(isnan)(bfloat16 x) {
  return sycl::ext::oneapi::experimental::isnan(x);
}

// Logistic function 1 / (1 + e^-x), evaluated in bfloat16.
inline bfloat16 FN_BF16(sigmoid)(bfloat16 x) {
  return bfloat16(1.0f) /
         (bfloat16(1.0f) + sycl::ext::oneapi::experimental::exp(-x));
}

// a raised to the power b.
inline bfloat16 FN_BF16(pow)(bfloat16 a, bfloat16 b) {
  return sycl::ext::oneapi::experimental::pow(a, b);
}

#endif

// *************************************************************** //
// float16 unary and binary operator
// Compiled only when CINN_SYCL_FP16 is defined.
#ifdef CINN_SYCL_FP16

#define FN_FP16(func) cinn_sycl_##func##_fp16

// Half-precision math intrinsics; each forwards to the SYCL built-in of the
// same meaning, called with a sycl::half argument.
inline sycl::half FN_FP16(ceil)(sycl::half x) { return sycl::ceil(x); }
inline sycl::half FN_FP16(floor)(sycl::half x) { return sycl::floor(x); }
inline sycl::half FN_FP16(round)(sycl::half x) { return sycl::round(x); }
inline sycl::half FN_FP16(trunc)(sycl::half x) { return sycl::trunc(x); }

inline sycl::half FN_FP16(sin)(sycl::half x) { return sycl::sin(x); }
inline sycl::half FN_FP16(cos)(sycl::half x) { return sycl::cos(x); }

inline sycl::half FN_FP16(exp)(sycl::half x) { return sycl::exp(x); }
inline sycl::half FN_FP16(log)(sycl::half x) { return sycl::log(x); }
inline sycl::half FN_FP16(log2)(sycl::half x) { return sycl::log2(x); }
inline sycl::half FN_FP16(log10)(sycl::half x) { return sycl::log10(x); }

inline sycl::half FN_FP16(sqrt)(sycl::half x) { return sycl::sqrt(x); }
inline sycl::half FN_FP16(rsqrt)(sycl::half x) { return sycl::rsqrt(x); }

inline sycl::half FN_FP16(cbrt)(sycl::half x) { return sycl::cbrt(x); }

// `abs` maps to the floating-point absolute value, sycl::fabs.
inline sycl::half FN_FP16(abs)(sycl::half x) { return sycl::fabs(x); }

// Classification predicates.
inline bool FN_FP16(isnan)(sycl::half x) { return sycl::isnan(x); }
inline bool FN_FP16(isinf)(sycl::half x) { return sycl::isinf(x); }
inline bool FN_FP16(isfinite)(sycl::half x) { return sycl::isfinite(x); }

inline sycl::half FN_FP16(erf)(sycl::half x) { return sycl::erf(x); }

inline sycl::half FN_FP16(tan)(sycl::half x) { return sycl::tan(x); }
inline sycl::half FN_FP16(sinh)(sycl::half x) { return sycl::sinh(x); }
inline sycl::half FN_FP16(cosh)(sycl::half x) { return sycl::cosh(x); }
inline sycl::half FN_FP16(tanh)(sycl::half x) { return sycl::tanh(x); }
inline sycl::half FN_FP16(asin)(sycl::half x) { return sycl::asin(x); }
inline sycl::half FN_FP16(acos)(sycl::half x) { return sycl::acos(x); }
inline sycl::half FN_FP16(atan)(sycl::half x) { return sycl::atan(x); }
inline sycl::half FN_FP16(asinh)(sycl::half x) { return sycl::asinh(x); }
inline sycl::half FN_FP16(acosh)(sycl::half x) { return sycl::acosh(x); }
inline sycl::half FN_FP16(atanh)(sycl::half x) { return sycl::atanh(x); }

// Logistic function 1 / (1 + e^-x), evaluated in half precision.
inline sycl::half FN_FP16(sigmoid)(sycl::half x) {
  return static_cast<sycl::half>(1.0f) /
         (static_cast<sycl::half>(1.0f) + sycl::exp(-x));
}

// Modulo whose non-zero result takes the sign of the divisor `b`: the
// sycl::fmod remainder is shifted by `b` when its sign differs from `b`'s.
inline sycl::half FN_FP16(mod)(sycl::half a, sycl::half b) {
  sycl::half res = sycl::fmod(a, b);
  if ((res != 0.0) && ((res < 0.0) != (b < 0.0))) res += b;
  return res;
}
// a raised to the power b.
inline sycl::half FN_FP16(pow)(sycl::half a, sycl::half b) {
  return sycl::pow(a, b);
}

#endif

// *************************************************************** //
// Reduce operators.
// Each EXPAND_REDUCE_* table invokes the given macro once per reduction with
// the arguments (reduction name, identity value, element type, extra args...).
// For every reduction name there is a matching binary combiner `cinn_<name>`.
// NOTE(review): the original remark here said "need `--expt-relaxed-constexpr`
// option to call std function in device kernel"; that looks like an nvcc flag
// carried over from a CUDA counterpart — confirm whether it applies to SYCL.
// The table name is spelled "MARCO" and is referenced under that spelling by
// the instantiation lines further down.
#define EXPAND_REDUCE_INT32_MARCO(MARCO, ...)                               \
  MARCO(sum_int32, 0, int, ##__VA_ARGS__)                                   \
  MARCO(prod_int32, 1, int, ##__VA_ARGS__)                                  \
  MARCO(max_int32, std::numeric_limits<int32_t>::min(), int, ##__VA_ARGS__) \
  MARCO(min_int32, std::numeric_limits<int32_t>::max(), int, ##__VA_ARGS__)

// Binary combiners for the int32 reductions listed above.
inline int cinn_sum_int32(const int left, const int right) {
  return left + right;
}
inline int cinn_prod_int32(const int left, const int right) {
  return left * right;
}
inline int cinn_max_int32(const int left, const int right) {
  return sycl::max(left, right);
}
inline int cinn_min_int32(const int left, const int right) {
  return sycl::min(left, right);
}

// int64 reduction table: (reduction name, identity value, element type).
// The identity of max is the smallest int64_t and the identity of min is the
// largest int64_t.
#define EXPAND_REDUCE_INT64_MARCO(MARCO, ...)                                 \
  MARCO(sum_int64, 0, int64_t, ##__VA_ARGS__)                                 \
  MARCO(prod_int64, 1, int64_t, ##__VA_ARGS__)                                \
  MARCO(                                                                      \
      max_int64, std::numeric_limits<int64_t>::min(), int64_t, ##__VA_ARGS__) \
  MARCO(min_int64, std::numeric_limits<int64_t>::max(), int64_t, ##__VA_ARGS__)

// Binary combiners for the int64 reductions listed above.
inline int64_t cinn_sum_int64(const int64_t left, const int64_t right) {
  return left + right;
}
inline int64_t cinn_prod_int64(const int64_t left, const int64_t right) {
  return left * right;
}
inline int64_t cinn_max_int64(const int64_t left, const int64_t right) {
  return sycl::max(left, right);
}
inline int64_t cinn_min_int64(const int64_t left, const int64_t right) {
  return sycl::min(left, right);
}

// fp32 reduction table: (reduction name, identity value, element type).
// The identity of max must be the most negative finite float, i.e.
// numeric_limits<float>::lowest(). numeric_limits<float>::min() is the
// smallest *positive* normalized value and would make a max-reduction over
// all-negative inputs return a positive number.
#define EXPAND_REDUCE_FP32_MACRO(MACRO, ...)                                  \
  MACRO(sum_fp32, 0.0f, float, ##__VA_ARGS__)                                 \
  MACRO(prod_fp32, 1.0f, float, ##__VA_ARGS__)                                \
  MACRO(max_fp32, std::numeric_limits<float>::lowest(), float, ##__VA_ARGS__) \
  MACRO(min_fp32, std::numeric_limits<float>::max(), float, ##__VA_ARGS__)

// Binary combiners for the fp32 reductions (sum, prod, max, min).
// max and min use sycl::fmax / sycl::fmin.
inline float cinn_sum_fp32(const float left, const float right) {
  return left + right;
}
inline float cinn_prod_fp32(const float left, const float right) {
  return left * right;
}
inline float cinn_max_fp32(const float left, const float right) {
  return sycl::fmax(left, right);
}
inline float cinn_min_fp32(const float left, const float right) {
  return sycl::fmin(left, right);
}

#ifdef CINN_SYCL_BF16

// bf16 reduction table: (reduction name, identity value, element type).
// The identity of max must be the most negative finite value. It is written
// as -numeric_limits<bfloat16>::max() so that only the members this file
// already relied on (max() and unary minus) are required.
// numeric_limits<bfloat16>::min() is a small *positive* value and would make
// a max-reduction over all-negative inputs return a positive number.
#define EXPAND_REDUCE_BF16_MACRO(MACRO, ...)                                   \
  MACRO(sum_bf16, bfloat16(0.0), bfloat16, ##__VA_ARGS__)                      \
  MACRO(prod_bf16, bfloat16(1.0), bfloat16, ##__VA_ARGS__)                     \
  MACRO(max_bf16,                                                              \
        -std::numeric_limits<bfloat16>::max(),                                 \
        bfloat16,                                                              \
        ##__VA_ARGS__)                                                         \
  MACRO(min_bf16, std::numeric_limits<bfloat16>::max(), bfloat16, ##__VA_ARGS__)

// Binary combiners for the bf16 reductions listed above.
inline bfloat16 cinn_sum_bf16(const bfloat16 left, const bfloat16 right) {
  return left + right;
}
inline bfloat16 cinn_prod_bf16(const bfloat16 left, const bfloat16 right) {
  return left * right;
}
inline bfloat16 cinn_max_bf16(const bfloat16 left, const bfloat16 right) {
  return sycl::ext::oneapi::experimental::max(left, right);
}
inline bfloat16 cinn_min_bf16(const bfloat16 left, const bfloat16 right) {
  return sycl::ext::oneapi::experimental::min(left, right);
}
#endif

#ifdef CINN_SYCL_FP16

// fp16 reduction table: (reduction name, identity value, element type).
// The identity of max must be the most negative finite value. It is written
// as -numeric_limits<sycl::half>::max() so that only the members this file
// already relied on (max() and unary minus) are required.
// numeric_limits<sycl::half>::min() is a small *positive* value and would
// make a max-reduction over all-negative inputs return a positive number.
#define EXPAND_REDUCE_FP16_MACRO(MACRO, ...)                   \
  MACRO(sum_fp16, sycl::half(0.0), sycl::half, ##__VA_ARGS__)  \
  MACRO(prod_fp16, sycl::half(1.0), sycl::half, ##__VA_ARGS__) \
  MACRO(max_fp16,                                              \
        -std::numeric_limits<sycl::half>::max(),               \
        sycl::half,                                            \
        ##__VA_ARGS__)                                         \
  MACRO(min_fp16,                                              \
        std::numeric_limits<sycl::half>::max(),                \
        sycl::half,                                            \
        ##__VA_ARGS__)

// Binary combiners for the fp16 reductions listed above.
inline sycl::half cinn_sum_fp16(const sycl::half left, const sycl::half right) {
  return left + right;
}
inline sycl::half cinn_prod_fp16(const sycl::half left,
                                 const sycl::half right) {
  return left * right;
}
inline sycl::half cinn_max_fp16(const sycl::half left, const sycl::half right) {
  return sycl::fmax(left, right);
}
inline sycl::half cinn_min_fp16(const sycl::half left, const sycl::half right) {
  return sycl::fmin(left, right);
}
#endif

// fp64 reduction table: (reduction name, identity value, element type).
// The identity of max must be the most negative finite double, i.e.
// numeric_limits<double>::lowest(). numeric_limits<double>::min() is the
// smallest *positive* normalized value and would make a max-reduction over
// all-negative inputs return a positive number.
#define EXPAND_REDUCE_FP64_MACRO(MACRO, ...)                                 \
  MACRO(sum_fp64, 0.0, double, ##__VA_ARGS__)                                \
  MACRO(prod_fp64, 1.0, double, ##__VA_ARGS__)                               \
  MACRO(                                                                     \
      max_fp64, std::numeric_limits<double>::lowest(), double, ##__VA_ARGS__) \
  MACRO(min_fp64, std::numeric_limits<double>::max(), double, ##__VA_ARGS__)

// Binary combiners for the fp64 reductions (sum, prod, max, min).
// max and min use sycl::fmax / sycl::fmin.
inline double cinn_sum_fp64(const double left, const double right) {
  return left + right;
}
inline double cinn_prod_fp64(const double left, const double right) {
  return left * right;
}
inline double cinn_max_fp64(const double left, const double right) {
  return sycl::fmax(left, right);
}
inline double cinn_min_fp64(const double left, const double right) {
  return sycl::fmin(left, right);
}

// bool reduction table: (reduction name, identity value, element type).
// `all` starts from true and `any` starts from false.
#define EXPAND_REDUCE_BOOL_MACRO(MACRO, ...) \
  MACRO(all, true, bool, ##__VA_ARGS__)      \
  MACRO(any, false, bool, ##__VA_ARGS__)

// Combiner for `all`: true only when both operands are true.
inline bool cinn_all(const bool left, const bool right) {
  return left ? right : false;
}
// Combiner for `any`: true when at least one operand is true.
inline bool cinn_any(const bool left, const bool right) {
  return left ? true : right;
}

/*
 * Generates cinn_warp_shuffle_<REDUCE_TYPE>_internal(value, item_ct1): a
 * reduction of `value` across the calling work-item's sub-group, built from
 * sycl::shift_group_left with offsets MAX_SUBGROUP_SIZE / 2, ..., 2, 1.
 *
 * - When the sub-group is narrower than MAX_SUBGROUP_SIZE, a shifted value is
 *   only used if its source lane (lane id + offset) lies inside the
 *   sub-group; otherwise the reduction's identity value is combined instead.
 * - Otherwise every shifted value is combined unconditionally.
 *
 * After the loop only lane 0 is guaranteed to hold the complete result (the
 * higher lanes combined values coming from out-of-range source lanes), so the
 * result of lane 0 is broadcast to the whole sub-group before returning.
 * Fix: the full-width branch previously returned without this broadcast, so
 * lanes other than 0 could return an incomplete result.
 *
 * Every work-item of the sub-group must call the generated function.
 */
#define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE)     \
  inline DTYPE cinn_warp_shuffle_##REDUCE_TYPE##_internal(                     \
      const DTYPE value, const sycl::nd_item<3> &item_ct1) {                   \
    DTYPE tmp_val = value, shfl_res;                                           \
    unsigned int subgroup_size =                                               \
        item_ct1.get_sub_group().get_local_range()[0];                         \
    unsigned int threadId_in_subgroup =                                        \
        item_ct1.get_sub_group().get_local_id()[0];                            \
    if (subgroup_size < MAX_SUBGROUP_SIZE) {                                   \
      for (unsigned int offset = MAX_SUBGROUP_SIZE / 2; offset >= 1;           \
           offset /= 2) {                                                      \
        shfl_res =                                                             \
            sycl::shift_group_left(item_ct1.get_sub_group(), tmp_val, offset); \
        tmp_val =                                                              \
            cinn_##REDUCE_TYPE(tmp_val,                                        \
                               (threadId_in_subgroup + offset) < subgroup_size \
                                   ? shfl_res                                  \
                                   : (DTYPE)(INITIAL_VALUE));                  \
      }                                                                        \
    } else {                                                                   \
      for (unsigned int offset = MAX_SUBGROUP_SIZE / 2; offset >= 1;           \
           offset /= 2) {                                                      \
        shfl_res =                                                             \
            sycl::shift_group_left(item_ct1.get_sub_group(), tmp_val, offset); \
        tmp_val = cinn_##REDUCE_TYPE(tmp_val, shfl_res);                       \
      }                                                                        \
    }                                                                          \
    return sycl::select_from_group(item_ct1.get_sub_group(), tmp_val, 0);      \
  }

EXPAND_REDUCE_INT32_MARCO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
EXPAND_REDUCE_INT64_MARCO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
EXPAND_REDUCE_FP32_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
EXPAND_REDUCE_FP64_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
EXPAND_REDUCE_BOOL_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)

#ifdef CINN_SYCL_BF16
EXPAND_REDUCE_BF16_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
#endif

#ifdef CINN_SYCL_FP16
EXPAND_REDUCE_FP16_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
#endif

#undef CINN_WARP_SHUFFLE_INTERNAL_IMPL

/*
 * Generates cinn_warp_reduce_<REDUCE_TYPE>(buf, offset, extend, item_ct1):
 * a sub-group reduction over buf[offset] ... buf[offset + extend - 1].
 *
 * Each work-item starts from the reduction's identity value and accumulates
 * the elements buf[offset + i] for i = lane id, lane id + sub-group size, ...
 * while i < extend. The per-work-item partial results are then combined with
 * cinn_warp_shuffle_<REDUCE_TYPE>_internal, whose return value is returned.
 */
#define CINN_WARP_REDUCE_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE)          \
  inline DTYPE cinn_warp_reduce_##REDUCE_TYPE(                            \
      const DTYPE *buf,                                                   \
      int offset,                                                         \
      int extend,                                                         \
      const sycl::nd_item<3> &item_ct1) {                                 \
    DTYPE tmp_val = (DTYPE)(INITIAL_VALUE);                               \
    unsigned int subgroup_size =                                          \
        item_ct1.get_sub_group().get_local_range()[0];                    \
    for (int i = item_ct1.get_sub_group().get_local_id()[0]; i < extend;  \
         i += subgroup_size) {                                            \
      tmp_val = cinn_##REDUCE_TYPE(tmp_val, buf[offset + i]);             \
    }                                                                     \
    return cinn_warp_shuffle_##REDUCE_TYPE##_internal(tmp_val, item_ct1); \
  }

// Instantiate the sub-group reduction for every reduction table.
EXPAND_REDUCE_INT32_MARCO(CINN_WARP_REDUCE_IMPL)
EXPAND_REDUCE_INT64_MARCO(CINN_WARP_REDUCE_IMPL)
EXPAND_REDUCE_FP32_MACRO(CINN_WARP_REDUCE_IMPL)
EXPAND_REDUCE_FP64_MACRO(CINN_WARP_REDUCE_IMPL)
EXPAND_REDUCE_BOOL_MACRO(CINN_WARP_REDUCE_IMPL)

#ifdef CINN_SYCL_BF16
EXPAND_REDUCE_BF16_MACRO(CINN_WARP_REDUCE_IMPL)
#endif

#ifdef CINN_SYCL_FP16
EXPAND_REDUCE_FP16_MACRO(CINN_WARP_REDUCE_IMPL)
#endif

#undef CINN_WARP_REDUCE_IMPL

// Sub-group average of buf[offset] ... buf[offset + extend - 1]: the
// sub-group sum divided by the element count `extend`.
inline float cinn_warp_reduce_avg_fp32(const float *buf,
                                       int offset,
                                       int extend,
                                       const sycl::nd_item<3> &item_ct1) {
  const float total =
      cinn_warp_reduce_sum_fp32(buf, offset, extend, item_ct1);
  return total / extend;
}

/*
 * Work-group wide reduction of one value per work-item.
 *
 * CINN_BLOCK_REDUCE_INTERNAL_IMPL expands to the body of
 * cinn_block_reduce_<REDUCE_TYPE>_internal(value, item_ct1):
 *   1. sub-group 0 writes the reduction's identity value into the scratch
 *      slots indexed by get_local_id(2) of its work-items;
 *   2. every sub-group reduces its values with the sub-group shuffle helper;
 *   3. each sub-group leader stores its partial result in
 *      scratch[sub-group id];
 *   4. sub-group 0 reloads scratch[get_local_id(2)], reduces those values
 *      with the same shuffle helper and publishes the result in scratch[0];
 *   5. every work-item returns scratch[0].
 * Work-group barriers (local-memory fence) separate steps 1/3, 3/4 and 4/5.
 *
 * The scratch buffer lives in work-group local memory and is indexed both by
 * sub-group id (step 3) and by the work-item id of sub-group 0 (steps 1, 4).
 * It therefore needs max(MAX_SUBGROUP_SIZE, MAX_SUBGROUPNUM_INGROUP) slots.
 * Fix: it was previously sized MAX_SUBGROUPNUM_INGROUP only, which is smaller
 * than the sub-group width, so steps 1 and 4 accessed slots past its end.
 *
 * NOTE(review): step 4 reduces one slot per work-item of sub-group 0, so the
 * algorithm presumably requires the number of sub-groups not to exceed the
 * sub-group width — confirm against the launch configurations in use.
 * NOTE(review): the early return triggers when the sub-group width is 1;
 * a check for "only one sub-group in the work-group" may have been intended
 * — confirm.
 */
#define CINN_BLOCK_REDUCE_SCRATCH_SIZE                                    \
  (MAX_SUBGROUP_SIZE > MAX_SUBGROUPNUM_INGROUP ? MAX_SUBGROUP_SIZE        \
                                               : MAX_SUBGROUPNUM_INGROUP)

#define CINN_BLOCK_REDUCE_INTERNAL_IMPL(                                     \
    TYPE, value, init_value, cinn_warp_shuffle_internal)                     \
  unsigned int subgroup_id = item_ct1.get_sub_group().get_group_id()[0];     \
  auto tmp = *sycl::group_local_memory<TYPE[CINN_BLOCK_REDUCE_SCRATCH_SIZE]>( \
      item_ct1.get_group());                                                 \
  if (subgroup_id == 0) {                                                    \
    tmp[item_ct1.get_local_id(2)] = init_value;                              \
  }                                                                          \
  TYPE tmp_val = cinn_warp_shuffle_internal(value, item_ct1);                \
  if (item_ct1.get_sub_group().get_local_range()[0] == 1) {                  \
    return tmp_val;                                                          \
  }                                                                          \
  item_ct1.barrier(sycl::access::fence_space::local_space);                  \
  if (item_ct1.get_sub_group().leader()) {                                   \
    tmp[subgroup_id] = tmp_val;                                              \
  }                                                                          \
  item_ct1.barrier(sycl::access::fence_space::local_space);                  \
  if (subgroup_id == 0) {                                                    \
    tmp_val = tmp[item_ct1.get_local_id(2)];                                 \
    tmp_val = cinn_warp_shuffle_internal(tmp_val, item_ct1);                 \
    if (item_ct1.get_local_id(2) == 0) {                                     \
      tmp[0] = tmp_val;                                                      \
    }                                                                        \
  }                                                                          \
  item_ct1.barrier(sycl::access::fence_space::local_space);                  \
  return tmp[0];

// Wraps the body above into cinn_block_reduce_<REDUCE_TYPE>_internal.
#define CINN_BLOCK_REDUCE_INTERNAL_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \
  inline DTYPE cinn_block_reduce_##REDUCE_TYPE##_internal(                  \
      const DTYPE value, const sycl::nd_item<3> &item_ct1) {                \
    CINN_BLOCK_REDUCE_INTERNAL_IMPL(                                        \
        DTYPE,                                                              \
        value,                                                              \
        (DTYPE)(INITIAL_VALUE),                                             \
        cinn_warp_shuffle_##REDUCE_TYPE##_internal);                        \
  }

EXPAND_REDUCE_INT32_MARCO(CINN_BLOCK_REDUCE_INTERNAL_MACRO)
EXPAND_REDUCE_INT64_MARCO(CINN_BLOCK_REDUCE_INTERNAL_MACRO)
EXPAND_REDUCE_FP32_MACRO(CINN_BLOCK_REDUCE_INTERNAL_MACRO)
EXPAND_REDUCE_FP64_MACRO(CINN_BLOCK_REDUCE_INTERNAL_MACRO)
EXPAND_REDUCE_BOOL_MACRO(CINN_BLOCK_REDUCE_INTERNAL_MACRO)

#ifdef CINN_SYCL_BF16
EXPAND_REDUCE_BF16_MACRO(CINN_BLOCK_REDUCE_INTERNAL_MACRO)
#endif

#ifdef CINN_SYCL_FP16
EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_INTERNAL_MACRO)
#endif
#undef CINN_BLOCK_REDUCE_INTERNAL_IMPL
#undef CINN_BLOCK_REDUCE_INTERNAL_MACRO
#undef CINN_BLOCK_REDUCE_SCRATCH_SIZE

/*
 * Generates cinn_block_reduce_<REDUCE_TYPE>(buf, offset, extend, item_ct1):
 * a work-group reduction over buf[offset] ... buf[offset + extend - 1].
 *
 * Each work-item starts from the reduction's identity value and accumulates
 * the elements buf[offset + i] for i = get_local_id(2),
 * get_local_id(2) + get_local_range(2), ... while i < extend. The
 * per-work-item partial results are then combined with
 * cinn_block_reduce_<REDUCE_TYPE>_internal, whose return value is returned.
 */
#define CINN_BLOCK_REDUCE_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE)         \
  inline DTYPE cinn_block_reduce_##REDUCE_TYPE(                           \
      const DTYPE *buf,                                                   \
      int offset,                                                         \
      int extend,                                                         \
      const sycl::nd_item<3> &item_ct1) {                                 \
    DTYPE tmp_val = (DTYPE)(INITIAL_VALUE);                               \
    for (int i = item_ct1.get_local_id(2); i < extend;                    \
         i += item_ct1.get_local_range(2)) {                              \
      tmp_val = cinn_##REDUCE_TYPE(tmp_val, buf[offset + i]);             \
    }                                                                     \
    return cinn_block_reduce_##REDUCE_TYPE##_internal(tmp_val, item_ct1); \
  }

// Instantiate the work-group reduction for every reduction table.
EXPAND_REDUCE_INT32_MARCO(CINN_BLOCK_REDUCE_IMPL)
EXPAND_REDUCE_INT64_MARCO(CINN_BLOCK_REDUCE_IMPL)
EXPAND_REDUCE_FP32_MACRO(CINN_BLOCK_REDUCE_IMPL)
EXPAND_REDUCE_FP64_MACRO(CINN_BLOCK_REDUCE_IMPL)
EXPAND_REDUCE_BOOL_MACRO(CINN_BLOCK_REDUCE_IMPL)

#ifdef CINN_SYCL_BF16
EXPAND_REDUCE_BF16_MACRO(CINN_BLOCK_REDUCE_IMPL)
#endif

#ifdef CINN_SYCL_FP16
EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_IMPL)
#endif

#undef CINN_BLOCK_REDUCE_IMPL

// The reduction tables are only needed to instantiate the reduce helpers
// above; remove them so they do not leak into code that includes this file.
#undef EXPAND_REDUCE_INT32_MARCO
#undef EXPAND_REDUCE_INT64_MARCO
#undef EXPAND_REDUCE_FP32_MACRO
#undef EXPAND_REDUCE_FP64_MACRO
#undef EXPAND_REDUCE_BOOL_MACRO

#ifdef CINN_SYCL_BF16
#undef EXPAND_REDUCE_BF16_MACRO
#endif

#ifdef CINN_SYCL_FP16
#undef EXPAND_REDUCE_FP16_MACRO
#endif

// *************************************************************** //
// other function
// Backward linear searches.
// The *_nd variants inspect the `size` elements buf[begin],
// buf[begin + stride], ..., starting from the last one, and return the
// logical position (0-based, in units of `stride`) of the first match found,
// i.e. the last matching element of the sequence. The plain variants do the
// same over buf[0] ... buf[size - 1]. All return -1 when nothing matches.

inline int cinn_sycl_find_int(const int *buf, int size, int num) {
  int pos = size - 1;
  while (pos >= 0) {
    if (buf[pos] == num) {
      return pos;
    }
    pos -= 1;
  }
  return -1;
}

inline int cinn_sycl_find_float(const float *buf, int size, float num) {
  int pos = size - 1;
  while (pos >= 0) {
    if (buf[pos] == num) {
      return pos;
    }
    pos -= 1;
  }
  return -1;
}

inline int cinn_sycl_find_int_nd(
    const int *buf, int size, int num, int begin, int stride) {
  int pos = (size - 1) * stride + begin;
  while (pos >= begin) {
    if (buf[pos] == num) {
      return (pos - begin) / stride;
    }
    pos -= stride;
  }
  return -1;
}

inline int cinn_sycl_find_float_nd(
    const float *buf, int size, float num, int begin, int stride) {
  int pos = (size - 1) * stride + begin;
  while (pos >= begin) {
    if (buf[pos] == num) {
      return (pos - begin) / stride;
    }
    pos -= stride;
  }
  return -1;
}

// Finds the smallest of the `size` elements buf[begin], buf[begin + stride],
// ... (the first one wins on ties), overwrites it in place with the largest
// int32 value and returns its logical position (0-based, in units of
// `stride`). Returns -1 when no element is visited. `num` is not used.
inline int cinn_sycl_next_smallest_int32(
    int *buf, int size, int num, int begin, int stride) {
  const int end = begin + size * stride;
  if (begin >= end) {
    return -1;
  }
  int smallest = begin;
  for (int pos = begin + stride; pos < end; pos += stride) {
    if (buf[pos] < buf[smallest]) {
      smallest = pos;
    }
  }
  // Mark the element as consumed so a following call finds the next one.
  buf[smallest] = std::numeric_limits<int32_t>::max();
  return (smallest - begin) / stride;
}

// Forward linear search shared by the cinn_sycl_find_*_from functions below.
// Walks buf[begin], buf[begin + 1], ... up to (but excluding) buf[size] and
// returns, from the enclosing function, the index of the first element equal
// to `num`, or -1 when the range holds no match (including begin >= size).
#define CINN_SYCL_FIND_FROM_IMPL(buf, size, num, begin) \
  do {                                                  \
    int pos = (begin);                                  \
    while (pos < (size)) {                              \
      if ((buf)[pos] == (num)) return pos;              \
      ++pos;                                            \
    }                                                   \
    return -1;                                          \
  } while (0)

// Index of the first element in buf[begin .. size) equal to `num`, else -1.
inline int cinn_sycl_find_int_from(const int *buf,
                                   int size,
                                   int num,
                                   int begin) {
  CINN_SYCL_FIND_FROM_IMPL(buf, size, num, begin);
}

// Float overload of cinn_sycl_find_int_from; compares with exact `==`.
inline int cinn_sycl_find_float_from(const float *buf,
                                     int size,
                                     float num,
                                     int begin) {
  CINN_SYCL_FIND_FROM_IMPL(buf, size, num, begin);
}

#undef CINN_SYCL_FIND_FROM_IMPL

// Generates cinn_sycl_lt_num_<suffix>(buf, size, num, offset, stride), which
// counts how many of the `size` logical elements buf[offset + k * stride]
// (k = 0 .. size - 1) are strictly less than `num`. Returns 0 for size <= 0.
//
// The loop runs over the logical index k rather than the raw buffer offset so
// that it performs exactly `size` iterations for any stride. Stepping the raw
// offset never terminated when `stride == 0` and ran out of bounds for a
// negative stride; results for `stride > 0` are unchanged.
#define CINN_NVGPU_LT_NUM(TYPE_SUFFIX, TYPE)                    \
  inline int cinn_sycl_lt_num_##TYPE_SUFFIX(const TYPE *buf,    \
                                            const int size,     \
                                            const TYPE num,     \
                                            const int offset,   \
                                            const int stride) { \
    int out = 0;                                                \
    for (int k = 0; k < size; ++k) {                            \
      if (buf[offset + k * stride] < num) out++;                \
    }                                                           \
    return out;                                                 \
  }

CINN_NVGPU_LT_NUM(fp32, float)
CINN_NVGPU_LT_NUM(fp64, double)
CINN_NVGPU_LT_NUM(uint8, uint8_t)
CINN_NVGPU_LT_NUM(int16, int16_t)
CINN_NVGPU_LT_NUM(int32, int)
CINN_NVGPU_LT_NUM(int64, int64_t)
#ifdef CINN_SYCL_FP16
CINN_NVGPU_LT_NUM(fp16, sycl::half)
#endif

#undef CINN_NVGPU_LT_NUM

// Generates cinn_sycl_gt_num_<suffix>(buf, size, num, offset, stride), which
// counts how many of the `size` logical elements buf[offset + k * stride]
// (k = 0 .. size - 1) are strictly greater than `num`. Returns 0 for
// size <= 0.
//
// The loop runs over the logical index k rather than the raw buffer offset so
// that it performs exactly `size` iterations for any stride. Stepping the raw
// offset never terminated when `stride == 0` and ran out of bounds for a
// negative stride; results for `stride > 0` are unchanged.
#define CINN_NVGPU_GT_NUM(TYPE_SUFFIX, TYPE)                    \
  inline int cinn_sycl_gt_num_##TYPE_SUFFIX(const TYPE *buf,    \
                                            const int size,     \
                                            const TYPE num,     \
                                            const int offset,   \
                                            const int stride) { \
    int out = 0;                                                \
    for (int k = 0; k < size; ++k) {                            \
      if (buf[offset + k * stride] > num) out++;                \
    }                                                           \
    return out;                                                 \
  }

CINN_NVGPU_GT_NUM(fp32, float)
CINN_NVGPU_GT_NUM(fp64, double)
CINN_NVGPU_GT_NUM(uint8, uint8_t)
CINN_NVGPU_GT_NUM(int16, int16_t)
CINN_NVGPU_GT_NUM(int32, int)
CINN_NVGPU_GT_NUM(int64, int64_t)
#ifdef CINN_SYCL_FP16
CINN_NVGPU_GT_NUM(fp16, sycl::half)
#endif

#undef CINN_NVGPU_GT_NUM

// Generates cinn_sycl_index_add_<suffix>(x, axis_indice, y, offset, stride,
// index, index_size).
//
// Starting from `x`, adds y[offset + i * stride] for every position i in
// index[0 .. index_size) whose value equals `axis_indice`, visiting matches in
// ascending i, and returns the accumulated value. Returns `x` unchanged when
// nothing matches or when index_size <= 0.
//
// A single forward pass over `index` visits exactly the same matches, in the
// same order, as repeatedly restarting cinn_sycl_find_int_from after each hit.
#define CINN_NVGPU_INDEX_ADD(TYPE_SUFFIX, TYPE)                         \
  inline TYPE cinn_sycl_index_add_##TYPE_SUFFIX(const TYPE x,           \
                                                const int axis_indice,  \
                                                const TYPE *y,          \
                                                const int offset,       \
                                                const int stride,       \
                                                const int *index,       \
                                                const int index_size) { \
    TYPE res = x;                                                       \
    for (int i = 0; i < index_size; ++i) {                              \
      if (index[i] == axis_indice) {                                    \
        res += y[offset + i * stride];                                  \
      }                                                                 \
    }                                                                   \
    return res;                                                         \
  }

CINN_NVGPU_INDEX_ADD(bool, bool)
CINN_NVGPU_INDEX_ADD(int8, int8_t)
CINN_NVGPU_INDEX_ADD(int32, int32_t)
CINN_NVGPU_INDEX_ADD(int64, int64_t)
CINN_NVGPU_INDEX_ADD(fp32, float)
CINN_NVGPU_INDEX_ADD(fp64, double)
#ifdef CINN_SYCL_FP16
CINN_NVGPU_INDEX_ADD(fp16, sycl::half)
#endif

// Must name the macro defined above; undefining a different name would leave
// the generator macro visible to everything that includes this header.
#undef CINN_NVGPU_INDEX_ADD

// Bilinear resize sample for an int buffer.
//
// `buf` is indexed as
//   n * c_size * in_h * in_w + c * in_h * in_w + row * in_w + col,
// and the function returns the value of output pixel (y, x) of image n,
// channel c, when the in_h x in_w plane is resized to out_h x out_w.
//
// The output coordinate is mapped to the source with
//   src = (dst + 0.5) * (in / out) - 0.5,
// the four neighbours around floor(src) are read with their row/column
// clamped to [0, in - 1], and they are blended with the fractional parts as
// weights. The float result is converted to int on return (truncation toward
// zero).
//
// Declared `inline` like the other functions in this header so that including
// the header from more than one translation unit does not produce multiple
// definitions.
inline int cinn_sycl_resize_bilinear(const int *buf,
                                     const int c_size,
                                     const int in_h,
                                     const int in_w,
                                     const int out_h,
                                     const int out_w,
                                     const int n,
                                     const int c,
                                     const int y,
                                     const int x) {
  const float scale_y = static_cast<float>(in_h) / out_h;
  const float scale_x = static_cast<float>(in_w) / out_w;
  const float in_y = (y + 0.5F) * scale_y - 0.5F;
  const float in_x = (x + 0.5F) * scale_x - 0.5F;
  const int in_y_int = static_cast<int>(FN_FP32(floor)(in_y));
  const int in_x_int = static_cast<int>(FN_FP32(floor)(in_x));
  const float y_lerp = in_y - in_y_int;
  const float x_lerp = in_x - in_x_int;

  // Offset of the (n, c) plane; invariant across the 2x2 neighbourhood.
  const int plane_base = n * c_size * in_h * in_w + c * in_h * in_w;

  float p[2][2];
  for (int i = 0; i < 2; ++i) {
    // Clamp so that samples past the border repeat the edge row/column.
    const int near_y = FN_INT32(max)(FN_INT32(min)(in_y_int + i, in_h - 1), 0);
    const int row_base = plane_base + near_y * in_w;
    for (int j = 0; j < 2; ++j) {
      const int near_x =
          FN_INT32(max)(FN_INT32(min)(in_x_int + j, in_w - 1), 0);
      p[i][j] = buf[row_base + near_x];
    }
  }

  const float top = p[0][0] * (1.0F - x_lerp) + p[0][1] * x_lerp;
  const float bottom = p[1][0] * (1.0F - x_lerp) + p[1][1] * x_lerp;
  const float value = top * (1.0F - y_lerp) + bottom * y_lerp;
  // Explicit narrowing: the interpolated value is truncated to int.
  return static_cast<int>(value);
}

// Bicubic resize sample for an int buffer.
//
// `buf` is indexed as
//   n * c_size * in_h * in_w + c * in_h * in_w + row * in_w + col,
// and the function returns the value of output pixel (y, x) of image n,
// channel c, when the in_h x in_w plane is resized to out_h x out_w.
//
// The output coordinate is mapped to the source with
//   src = (dst + 0.5) * (in / out) - 0.5,
// and the 4x4 neighbourhood starting one row/column before floor(src) is read
// with rows/columns clamped to [0, in - 1]. Each row is reduced with the
// horizontal cubic weights (alpha = -0.5) and the four row results are then
// combined with the vertical weights. The float result is converted to int on
// return (truncation toward zero).
//
// The DPC++ migration tool flagged an earlier form of this function
// (DPCT1110: more than 128 bytes of declared locals, possible high register
// pressure). To address that, rows are accumulated as they are read instead
// of being staged in 4x4 and 4-element scratch arrays; the weight table
// w[2][4] is the only local array left. The floating-point operations and
// their order are unchanged.
//
// Declared `inline` like the other functions in this header so that including
// the header from more than one translation unit does not produce multiple
// definitions.
inline int cinn_sycl_resize_bicubic(const int *buf,
                                    const int c_size,
                                    const int in_h,
                                    const int in_w,
                                    const int out_h,
                                    const int out_w,
                                    const int n,
                                    const int c,
                                    const int y,
                                    const int x) {
  const float scale_y = static_cast<float>(in_h) / out_h;
  const float scale_x = static_cast<float>(in_w) / out_w;
  const float in_y = (y + 0.5F) * scale_y - 0.5F;
  const float in_x = (x + 0.5F) * scale_x - 0.5F;

  // floor() is evaluated once per axis and reused for both the integer base
  // coordinate and the fractional offset.
  const float floor_y = cinn_sycl_floor_fp32(in_y);
  const float floor_x = cinn_sycl_floor_fp32(in_x);
  const int in_y_int = static_cast<int>(floor_y);
  const int in_x_int = static_cast<int>(floor_x);
  const float y_fract = in_y - floor_y;
  const float x_fract = in_x - floor_x;

  // Cubic weights: w[0] for the horizontal axis, w[1] for the vertical.
  const float alpha = -0.5F;
  float w[2][4];

  for (int i = 0; i < 2; ++i) {
    const float t = (i == 0 ? x_fract : y_fract);
    const float t2 = t * t;
    const float t3 = t * t * t;
    w[i][0] = alpha * (t3 - 2 * t2 + t);
    w[i][1] = (alpha + 2) * t3 - (3 + alpha) * t2 + 1;
    w[i][2] = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t;
    w[i][3] = -alpha * t3 + alpha * t2;
  }

  // Offset of the (n, c) plane; invariant across the 4x4 neighbourhood.
  const int plane_base = n * c_size * in_h * in_w + c * in_h * in_w;

  float value = 0.0F;

  for (int i = 0; i < 4; ++i) {
    // Clamp so that samples past the border repeat the edge row/column.
    const int near_y =
        FN_INT32(max)(FN_INT32(min)(in_y_int + i - 1, in_h - 1), 0);
    const int row_base = plane_base + near_y * in_w;

    // Horizontal pass for this row.
    float row_acc = 0.0F;
    for (int j = 0; j < 4; ++j) {
      const int near_x =
          FN_INT32(max)(FN_INT32(min)(in_x_int + j - 1, in_w - 1), 0);
      const float pixel = buf[row_base + near_x];
      row_acc += pixel * w[0][j];
    }

    // Vertical pass: fold the row result into the final value.
    value += row_acc * w[1][i];
  }

  // Explicit narrowing: the interpolated value is truncated to int.
  return static_cast<int>(value);
}

// *************************************************************** //
// Final cleanup: undefine the helper macros introduced by this header so they
// do not remain defined for code that follows it.

// Sub-group / work-group size limits.
#undef MAX_SUBGROUP_SIZE
#undef MAX_THREADNUM_INGROUP
#undef MAX_SUBGROUPNUM_INGROUP
// Per-type function-name helpers (e.g. FN_BOOL(func) expands to
// cinn_sycl_<func>_bool).
#undef FN_BOOL
#undef FN_UINT8
#undef FN_INT8
#undef FN_INT16
#undef FN_FP32
#undef FN_FP64
#undef FN_INT32
#undef FN_INT64

// The bf16/fp16 helpers are only undefined when the matching feature macro is
// set.
#ifdef CINN_SYCL_BF16
#undef FN_BF16
#endif
#ifdef CINN_SYCL_FP16
#undef FN_FP16
#endif
}  // extern "C"
